{
 "cells": [
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:53:20.457124300Z",
     "start_time": "2024-05-08T00:53:19.648425400Z"
    }
   },
   "id": "3b06a5e1bbcc51ab",
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "# GAT layer"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c6cc22006b078da8"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class GATLayer(nn.Module):\n",
    "    def __init__(self, G, feature_attn_size, dropout, slope, device):\n",
    "        super(GATLayer, self).__init__()\n",
    "        self.device = device\n",
    "        self.disease_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 1)\n",
    "        self.metabolite_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 0)\n",
    "        self.G = G\n",
    "        self.slope = slope\n",
    "        self.me_fc = nn.Linear(G.ndata['me_sim'].shape[1], feature_attn_size, bias=False)\n",
    "        self.d_fc = nn.Linear(G.ndata['d_sim'].shape[1], feature_attn_size, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        gain = nn.init.calculate_gain('relu')\n",
    "        nn.init.xavier_normal_(self.me_fc.weight, gain=gain)\n",
    "        nn.init.xavier_normal_(self.d_fc.weight, gain=gain)\n",
    "\n",
    "    # 边的注意力系数\n",
    "    def edge_attention(self, edges):\n",
    "        '''通过逐元素相乘的方式计算边注意力系数，a是每条边的注意力系数'''\n",
    "        a = torch.sum(edges.src['z'].mul(edges.dst['z']), dim=1).unsqueeze(1)\n",
    "        return {'e': F.leaky_relu(a, negative_slope=self.slope)}\n",
    "\n",
    "    def message_func(self, edges):\n",
    "        return {'z': edges.src['z'], 'e': edges.data['e']}\n",
    "\n",
    "    def reduce_func(self, nodes):\n",
    "        alpha = F.softmax(nodes.mailbox['e'], dim=1)\n",
    "        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n",
    "        return {'h': F.elu(h)}\n",
    "\n",
    "    def forward(self, G):\n",
    "        self.G.apply_nodes(lambda nodes: {'z': self.dropout(self.d_fc(nodes.data['d_sim']))}, self.disease_nodes)\n",
    "        self.G.apply_nodes(lambda nodes: {'z': self.dropout(self.me_fc(nodes.data['me_sim']))}, self.metabolite_nodes)\n",
    "        self.G.apply_edges(self.edge_attention)\n",
    "        self.G.update_all(self.message_func, self.reduce_func)\n",
    "        return self.G.ndata.pop('h')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:53:20.478126500Z",
     "start_time": "2024-05-08T00:53:20.466144Z"
    }
   },
   "id": "10a24f623192e6da",
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "# MultiHead attention layer"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3876f32b793e58bf"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class MultiHeadGATLayer(nn.Module):  # 多头注意力\n",
    "    def __init__(self, G, feature_attn_size, num_heads, dropout, slope,device, merge='cat'):\n",
    "        super(MultiHeadGATLayer, self).__init__()\n",
    "        self.device = device\n",
    "        self.G = G\n",
    "        self.dropout = dropout\n",
    "        self.slope = slope\n",
    "        self.merge = merge\n",
    "        self.heads = nn.ModuleList()\n",
    "        for i in range(num_heads):\n",
    "            self.heads.append(GATLayer(G, feature_attn_size, dropout, slope, device))\n",
    "\n",
    "    def forward(self, G):\n",
    "        head_outs = [attn_head(G) for attn_head in self.heads]\n",
    "        if self.merge == 'cat':\n",
    "            return torch.cat(head_outs, dim=1).to(self.device)\n",
    "        else:\n",
    "            return torch.mean(torch.stack(head_outs), dim=0).to(self.device)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:53:20.496122400Z",
     "start_time": "2024-05-08T00:53:20.477123800Z"
    }
   },
   "id": "a4e4fa0ba59f3709",
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Node-level attention with metapath length 2"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "1764e5865bcf8ace"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class HAN_metapath_specific(nn.Module):\n",
    "    def __init__(self, G, feature_attn_size, out_dim, dropout, slope, device):\n",
    "        super(HAN_metapath_specific, self).__init__()\n",
    "        self.device = device\n",
    "        self.metabolite_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 0)\n",
    "        self.microbe_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 2)\n",
    "        self.disease_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 1)\n",
    "        self.G = G\n",
    "        self.slope = slope\n",
    "        self.me_fc = nn.Linear(G.ndata['me_sim'].shape[1], feature_attn_size, bias=False)  # 统一节点特征维数\n",
    "        self.mi_fc = nn.Linear(G.ndata['mi_sim'].shape[1], feature_attn_size, bias=False)\n",
    "        self.d_fc = nn.Linear(G.ndata['d_sim'].shape[1], feature_attn_size, bias=False)\n",
    "        self.me_fc1 = nn.Linear(feature_attn_size + 495, out_dim)   # 设置全连接层\n",
    "        self.d_fc1 = nn.Linear(feature_attn_size + 383, out_dim)\n",
    "        self.attn_fc = nn.Linear(feature_attn_size * 2, 1, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        gain = nn.init.calculate_gain('relu')\n",
    "        nn.init.xavier_normal_(self.me_fc.weight, gain=gain)\n",
    "        nn.init.xavier_normal_(self.mi_fc.weight, gain=gain)\n",
    "        nn.init.xavier_normal_(self.d_fc.weight, gain=gain)\n",
    "\n",
    "    def edge_attention(self, edges):\n",
    "        a = torch.sum(edges.src['z'].mul(edges.dst['z']), dim=1).unsqueeze(1)\n",
    "        '''z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)\n",
    "        a = self.attn_fc(z2)'''\n",
    "        return {'e': F.leaky_relu(a, negative_slope=self.slope)}\n",
    "\n",
    "    def message_func(self, edges):\n",
    "        return {'z': edges.src['z'], 'e': edges.data['e']}\n",
    "\n",
    "    def reduce_func(self, nodes):\n",
    "        alpha = F.softmax(nodes.mailbox['e'], dim=1)\n",
    "        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n",
    "        return {'h': F.elu(h)}\n",
    "\n",
    "    def forward(self, new_g, meta_path):\n",
    "        if meta_path == 'dmi':\n",
    "            new_g.apply_nodes(lambda nodes: {'z': self.dropout(self.d_fc(nodes.data['d_sim']))}, self.disease_nodes)\n",
    "            new_g.apply_nodes(lambda nodes: {'z': self.dropout(self.mi_fc(nodes.data['mi_sim']))}, self.microbe_nodes)\n",
    "            new_g.apply_edges(self.edge_attention)\n",
    "            new_g.update_all(self.message_func, self.reduce_func)\n",
    "            h_dmi = new_g.ndata.pop('h').to(self.device)\n",
    "            return h_dmi\n",
    "\n",
    "        elif meta_path == 'mime':\n",
    "            new_g.apply_nodes(lambda nodes: {'z': self.dropout(self.mi_fc(nodes.data['mi_sim']))}, self.microbe_nodes)\n",
    "            new_g.apply_nodes(lambda nodes: {'z': self.dropout(self.me_fc(nodes.data['me_sim']))}, self.metabolite_nodes)\n",
    "            new_g.apply_edges(self.edge_attention)\n",
    "            new_g.update_all(self.message_func, self.reduce_func)\n",
    "            h_mime = new_g.ndata.pop('h').to(self.device)\n",
    "            return h_mime"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:53:20.521126900Z",
     "start_time": "2024-05-08T00:53:20.505134300Z"
    }
   },
   "id": "cd5c390264dbd117",
   "execution_count": 4
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
